{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch_geometric.data import HeteroData\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.nn import Sequential, Linear\n",
    "from torch.nn import ReLU\n",
    "from torch_geometric.datasets import OGB_MAG\n",
    "from torch_geometric.nn import SAGEConv, to_hetero\n",
    "from torch_geometric.loader import NeighborLoader, HGTLoader\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy node features: 10 authors with 8 features each, 20 papers with 4\n",
    "# features each, plus one random target value per author.\n",
    "authors = torch.rand(10, 8)\n",
    "papers = torch.rand(20, 4)\n",
    "authors_y = torch.rand(10)\n",
    "\n",
    "# Toy author -> paper 'write' edges: 50 random (author, paper) index pairs.\n",
    "# Stacking gives a (2, 50) edge_index: row 0 = source ids, row 1 = target ids.\n",
    "write_from = torch.tensor(np.random.choice(10, size=50))\n",
    "write_to = torch.tensor(np.random.choice(20, size=50))\n",
    "write = torch.stack((write_from, write_to)).to(torch.long)\n",
    "\n",
    "# Toy paper -> paper 'cite' edges: 15 random (paper, paper) index pairs,\n",
    "# stored the same way as a (2, 15) edge_index.\n",
    "cite_from = torch.tensor(np.random.choice(20, size=15))\n",
    "cite_to = torch.tensor(np.random.choice(20, size=15))\n",
    "cite = torch.stack((cite_from, cite_to)).to(torch.long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#-------------------------Register HeteroData\n",
    "# Everything is declared in one mapping handed to HeteroData: node types are\n",
    "# keyed by their name, edge types by a (source, relation, target) tuple.\n",
    "data = HeteroData({\n",
    "    'author': {'x': authors, 'y': authors_y},\n",
    "    'paper': {'x': papers},\n",
    "    ('author', 'write', 'paper'): {'edge_index': write},\n",
    "    ('paper', 'cite', 'paper'): {'edge_index': cite},\n",
    "})\n",
    "\n",
    "# The result is discarded here: only a cell's last expression is displayed.\n",
    "data.metadata()\n",
    "\n",
    "# Collapse the many node and edge types into a single type of each.\n",
    "homogeneus_data = data.to_homogeneous()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If you want to store the data: to_dict() gives a dictionary view of it\n",
    "# (the returned value is discarded here).\n",
    "data.to_dict()\n",
    "\n",
    "#-------------------------Example of model with HeteroData\n",
    "# The toy graph has only 10 labelled author nodes, so the default absolute\n",
    "# split sizes (num_val=500, num_test=1000) would put every node in the\n",
    "# validation mask and leave train_mask empty. Float values are read as\n",
    "# ratios of the available nodes instead.\n",
    "transform = T.RandomNodeSplit(num_val=0.2, num_test=0.2)\n",
    "data = transform(data)\n",
    "\n",
    "#---------------------Model 1\n",
    "class GNN(torch.nn.Module):\n",
    "    \"\"\"Two-layer GraphSAGE network written for a single node/edge type.\n",
    "\n",
    "    It is turned into a heterogeneous model below with ``to_hetero``.\n",
    "    Both layers are built with ``(-1, -1)`` input sizes so the source and\n",
    "    target feature dimensions do not have to be hard-coded.\n",
    "\n",
    "    Args:\n",
    "        hidden_channels: Output size of the first SAGEConv layer.\n",
    "        out_channels: Output size of the second SAGEConv layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_channels, out_channels):\n",
    "        super().__init__()\n",
    "        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n",
    "        self.conv2 = SAGEConv((-1, -1), out_channels)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        \"\"\"Apply conv1, a ReLU, then conv2 and return the result.\"\"\"\n",
    "        x = self.conv1(x, edge_index).relu()\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return x\n",
    "\n",
    "model = GNN(hidden_channels=64, out_channels=2)\n",
    "model = to_hetero(model, data.metadata(), aggr='sum')\n",
    "\n",
    "##---------------------Model 2\n",
    "# Same architecture expressed with Sequential (this rebinds `model`).\n",
    "# SAGEConv needs (-1, -1) here as well: a fixed target size such as 1 would\n",
    "# not match the 8-/4-dimensional node features nor the 64-dimensional\n",
    "# hidden features fed to the second layer.\n",
    "model = Sequential('x, edge_index', [\n",
    "    (SAGEConv((-1, -1), 64), 'x, edge_index ->x'),\n",
    "    ReLU(inplace=True),\n",
    "    (SAGEConv((-1, -1), 64), 'x, edge_index ->x'),\n",
    "    ReLU(inplace=True),\n",
    "    (Linear(-1, 2), 'x -> x'),\n",
    "])\n",
    "\n",
    "model = to_hetero(model, data.metadata(), aggr='sum')\n",
    "\n",
    "#-------------------------Train Data\n",
    "# From here on `data` is the OGB_MAG graph, no longer the toy graph above.\n",
    "# NOTE(review): presumably downloads/preprocesses into '.data' on first use.\n",
    "dataset = OGB_MAG(root='.data', preprocess='metapath2vec', transform=T.ToUndirected())\n",
    "data = dataset[0]\n",
    "\n",
    "# The result is discarded here: only a cell's last expression is displayed.\n",
    "data.metadata()\n",
    "\n",
    "# Seed the sampler from the 'paper' nodes selected by train_mask, with\n",
    "# num_neighbors=[10, 10].\n",
    "train_input_nodes = ('paper', data['paper'].train_mask)\n",
    "train_loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True, input_nodes=train_input_nodes)\n",
    "\n",
    "# Print only the first sampled batch, then stop iterating.\n",
    "for t in train_loader:\n",
    "    print(t)\n",
    "    break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pinss",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
